Diagnostics example

We’ll go through some MCMC diagnostics using ArviZ.

Step one is to load some data. Rather than going through a whole modelling workflow, we’ll just take one of the example MCMC outputs that ArviZ provides via the function load_arviz_data.

This particular MCMC output has to do with measurements of soil radioactivity in the USA. You can read more about it here.

import arviz as az
import numpy as np
import xarray as xr 

idata = az.load_arviz_data("radon")
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset> Size: 4MB
      Dimensions:    (chain: 4, draw: 500, g_coef: 2, County: 85)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * g_coef     (g_coef) <U9 72B 'intercept' 'slope'
        * County     (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'
      Data variables:
          g          (chain, draw, g_coef) float64 32kB ...
          za_county  (chain, draw, County) float64 1MB ...
          b          (chain, draw) float64 16kB ...
          sigma_a    (chain, draw) float64 16kB ...
          a          (chain, draw, County) float64 1MB ...
          a_county   (chain, draw, County) float64 1MB ...
          sigma      (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.191355
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • posterior_predictive
      <xarray.Dataset> Size: 15MB
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 15MB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.449843
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • log_likelihood
      <xarray.Dataset> Size: 15MB
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 15MB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.448264
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • sample_stats
      <xarray.Dataset> Size: 150kB
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
      Data variables:
          step_size_bar     (chain, draw) float64 16kB ...
          diverging         (chain, draw) bool 2kB ...
          energy            (chain, draw) float64 16kB ...
          tree_size         (chain, draw) float64 16kB ...
          mean_tree_accept  (chain, draw) float64 16kB ...
          step_size         (chain, draw) float64 16kB ...
          depth             (chain, draw) int64 16kB ...
          energy_error      (chain, draw) float64 16kB ...
          lp                (chain, draw) float64 16kB ...
          max_energy_error  (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.197697
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • prior
      <xarray.Dataset> Size: 1MB
      Dimensions:        (chain: 1, draw: 500, County: 85, g_coef: 2)
      Coordinates:
        * chain          (chain) int64 8B 0
        * draw           (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
        * County         (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
        * g_coef         (g_coef) <U9 72B 'intercept' 'slope'
      Data variables:
          a_county       (chain, draw, County) float64 340kB ...
          sigma_log__    (chain, draw) float64 4kB ...
          sigma_a        (chain, draw) float64 4kB ...
          a              (chain, draw, County) float64 340kB ...
          b              (chain, draw) float64 4kB ...
          za_county      (chain, draw, County) float64 340kB ...
          sigma          (chain, draw) float64 4kB ...
          g              (chain, draw, g_coef) float64 8kB ...
          sigma_a_log__  (chain, draw) float64 4kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.454586
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • prior_predictive
      <xarray.Dataset> Size: 4MB
      Dimensions:  (chain: 1, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 4MB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.457652
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • observed_data
      <xarray.Dataset> Size: 15kB
      Dimensions:  (obs_id: 919)
      Coordinates:
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (obs_id) float64 7kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.458415
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • constant_data
      <xarray.Dataset> Size: 21kB
      Dimensions:     (obs_id: 919, County: 85)
      Coordinates:
        * obs_id      (obs_id) int64 7kB 0 1 2 3 4 5 6 ... 912 913 914 915 916 917 918
        * County      (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
      Data variables:
          floor_idx   (obs_id) int32 4kB ...
          county_idx  (obs_id) int32 4kB ...
          uranium     (County) float64 680B ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.459832
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

ArviZ provides a data structure called InferenceData for storing MCMC outputs. It’s worth getting to know it: there is some helpful explanation here.

At a high level, an InferenceData is a container for several xarray Dataset objects called ‘groups’. Each group contains xarray [DataArray](https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html) objects called ‘variables’. Each variable holds a rectangular array of values, together with named axes for those values (‘dimensions’) and labels along each axis (‘coordinates’).

For example, if you expand the listing above you will see that the group posterior contains a variable called a_county with three dimensions called chain, draw and County. There are 85 counties, and the first one is labelled 'AITKIN'.
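To make this structure concrete, here is a minimal hand-built DataArray mimicking the posterior group’s a_county variable. The shapes and the two county names are illustrative stand-ins, not the full radon dataset:

```python
import numpy as np
import xarray as xr

# A miniature analogue of a_county: 2 chains x 3 draws x 2 counties,
# with labelled coordinates on every dimension.
a_county = xr.DataArray(
    np.zeros((2, 3, 2)),
    dims=["chain", "draw", "County"],
    coords={"chain": [0, 1], "draw": [0, 1, 2], "County": ["AITKIN", "ANOKA"]},
)

# Label-based selection drops the County dimension, leaving chain and draw
aitkin = a_county.sel(County="AITKIN")
print(aitkin.dims)  # ('chain', 'draw')
```

The same `.sel` pattern works on the real idata, e.g. selecting one county’s posterior draws by its coordinate label.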

The function az.summary lets us look at some useful summary statistics, including \(\hat{R}\), divergent transitions and the Monte Carlo standard error (MCSE).

The variable lp, which you can find in the group sample_stats, is the model’s total log probability density. It’s not very meaningful on its own, but it is useful for judging overall convergence. The variable diverging records whether each transition diverged, so its summary tells you whether any divergent transitions occurred.

az.summary(idata.sample_stats, var_names=["lp", "diverging"])
.../site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
.../site-packages/arviz/stats/diagnostics.py:991: RuntimeWarning: invalid value encountered in scalar divide
  varsd = varvar / evar / 4
               mean     sd     hdi_3%    hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
lp        -1143.808  9.272  -1160.741  -1126.532      0.426    0.242     477.0     772.0   1.01
diverging     0.000  0.000      0.000      0.000      0.000      NaN    2000.0    2000.0    NaN

In this case there were no post-warmup diverging transitions, and the \(\hat{R}\) statistic for the lp variable is pretty close to 1: great!
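You can also count divergences directly from the sample_stats group. The snippet below builds a synthetic InferenceData with az.from_dict so it runs standalone; on the radon data you would apply the same sum to idata.sample_stats:

```python
import numpy as np
import arviz as az

# Synthetic stand-in for a sample_stats group in which no transition
# diverged (4 chains x 500 draws), so the snippet needs no download.
idata_toy = az.from_dict(
    sample_stats={"diverging": np.zeros((4, 500), dtype=bool)}
)

# Total number of post-warmup divergent transitions across all chains
n_divergent = int(idata_toy.sample_stats["diverging"].sum())
print(n_divergent)  # 0
```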

Sometimes it’s useful to summarise individual parameters. This can be done by pointing az.summary at the group where the parameters of interest live. In this case the group is called posterior.

az.summary(idata.posterior, var_names=["sigma", "g"])
               mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
sigma         0.729  0.018   0.696    0.764      0.000    0.000    2228.0    1371.0    1.0
g[intercept]  1.494  0.036   1.427    1.562      0.001    0.001    1240.0    1566.0    1.0
g[slope]      0.700  0.093   0.526    0.880      0.003    0.002    1157.0    1066.0    1.0

Now we can start evaluating the model. First we check to see whether replicated measurements from the model’s posterior predictive distribution broadly agree with the observed measurements, using the arviz function plot_lm:

az.style.use("arviz-doc")
az.plot_lm(
    y=idata.observed_data["y"],
    x=idata.observed_data["obs_id"],
    y_hat=idata.posterior_predictive["y"],
    figsize=[12, 5],
    grid=False
);

The function az.loo can quickly estimate a model’s out-of-sample log likelihood (which, as we saw above, is a nice default loss function), allowing numerical comparison between models.

Watch out for the warning column, which can tell you if the estimation is likely to be incorrect. It’s usually a good idea to set the pointwise argument to True, as this allows for more detailed analysis at the per-observation level.

az.loo(idata, var_name="y", pointwise=True)
Computed from 2000 posterior samples and 919 observations log-likelihood matrix.

         Estimate       SE
elpd_loo -1027.18    28.85
p_loo       26.82        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)      919  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)      0    0.0%
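With pointwise=True, the returned object carries one Pareto k estimate per observation, which you can inspect directly via its pareto_k attribute. The sketch below uses made-up pointwise log likelihoods so it runs standalone; on the radon data you would inspect the result of the az.loo call above instead:

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(1)
# Toy InferenceData: 4 chains x 100 draws, 50 observations of made-up
# pointwise log likelihoods (purely illustrative numbers).
idata_toy = az.from_dict(
    posterior={"mu": rng.normal(0.0, 1.0, (4, 100))},
    log_likelihood={"y": rng.normal(-1.0, 0.1, (4, 100, 50))},
)

loo_res = az.loo(idata_toy, var_name="y", pointwise=True)
# One Pareto k shape estimate per observation
print(loo_res.pareto_k.shape)  # (50,)
```

Observations with k above 0.7 are the ones to investigate individually.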

The function az.compare is useful for comparing different out-of-sample log likelihood estimates.

idata.log_likelihood["fake"] = xr.DataArray(
    # generate some fake log likelihoods
    np.random.normal(0, 2, [4, 500, 919]),
    coords=idata.log_likelihood.coords,
    dims=["chain", "draw", "obs_id"],  # spell out the dims explicitly
)
comparison = az.compare(
    {
        "real": az.loo(idata, var_name="y"), 
        "fake": az.loo(idata, var_name="fake")
    }
)
comparison
.../site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
      rank     elpd_loo        p_loo   elpd_diff    weight         se        dse  warning  scale
real     0 -1027.176982    26.817161    0.000000  0.984786  28.848998   0.000000    False    log
fake     1 -1800.021190  3632.470122  772.844209  0.015214   3.464644  29.082427     True    log

The function az.plot_compare shows these results on a nice graph:

az.plot_compare(comparison);